import torch
import torch.distributions
from torchvision import datasets, transforms
from torch.utils.data import Sampler

from utils.datasets.tinyImages import _load_cifar_exclusion_idcs
from utils.datasets.paths import get_CIFAR10_path, get_CIFAR100_path
from utils.datasets.openimages import OpenImages
import numpy as np

from utils.datasets.imagenet_subsets import get_ImageNet100, get_ImageNet100_labels
from utils.datasets.imagenet_augmentation import get_imageNet_augmentation
from .imagenet100_train_val import ImageNet100TrainValidationSplit
from utils.datasets.paths import get_openimages_path, get_imagenet_path

from utils.datasets.tinyImages import _load_tiny_image, _preload_tiny_images
from .loading_utils import load_teacher_data

def get_openImages_partition(dataset_classifications_path, teacher_model, samples_per_class, imagenet100TrainValSplit,
                             class_tpr_min=None, od_exclusion_threshold=None, calibrate_temperature=False,
                             id_class_balanced=True, verbose_exclude=False, soft_labels=True, batch_size=100,
                             augm_type='default', size=224, num_workers=8, exclude_imagenet100=False,
                             id_config_dict=None, od_config_dict=None, ssl_config=None):
    """Build the top-k (in-distribution) and bottom (out-distribution) loaders.

    Teacher outputs are loaded with ``load_teacher_data``. The top-k dataset
    combines the ImageNet100 training data with the highest-confidence
    teacher-selected samples per class; every index it does not use is handed
    to the bottom partition.

    Args:
        dataset_classifications_path: Forwarded to ``load_teacher_data``.
        teacher_model: Forwarded to ``load_teacher_data``.
        samples_per_class: Upper bound on selected samples per class.
        imagenet100TrainValSplit: Forwarded to the top-k dataset, which only
            supports a truthy value.
        class_tpr_min, od_exclusion_threshold, calibrate_temperature,
        ssl_config: Forwarded to ``load_teacher_data``.
        id_class_balanced: Only recorded in ``id_config_dict``.
        verbose_exclude: If true, all indices passing the confidence threshold
            (not just the top-k ones) are removed from the bottom partition.
        soft_labels: Forwarded to both datasets.
        batch_size: Batch size of both loaders.
        augm_type, size: Forwarded to ``get_imageNet_augmentation``.
        num_workers: Worker count of the top-k loader (the bottom loader
            always uses a single worker).
        exclude_imagenet100: Forwarded to both datasets as their exclusion flag.
        id_config_dict, od_config_dict: Optional dicts that are filled in place
            with a description of the respective loader.

    Returns:
        Tuple ``(top_loader, bottom_loader)``.
    """
    model_confidences, _, class_thresholds, temperature = load_teacher_data(
        dataset_classifications_path, teacher_model,
        class_tpr_min=class_tpr_min,
        od_exclusion_threshold=od_exclusion_threshold,
        calibrate_temperature=calibrate_temperature,
        ssl_config=ssl_config)

    augm_config = {}
    transform = get_imageNet_augmentation(augm_type, out_size=size, config_dict=augm_config)

    top_dataset = ImageNet100OpenImagesTopKPartition(model_confidences, samples_per_class=samples_per_class,
                                                     transform=transform,
                                                     imageNet100TrainValSplit=imagenet100TrainValSplit,
                                                     min_conf=class_thresholds, temperature=temperature,
                                                     soft_labels=soft_labels, exclude_imageNet100=exclude_imagenet100)
    top_loader = torch.utils.data.DataLoader(top_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers)

    top_k_indices = top_dataset.get_used_semi_indices(verbose_exclude)
    # The original call also passed `exclude_cifar10_1=exclude_cifar10_1`, a name
    # that is not defined in this function (NameError); the class default is
    # used instead.
    # NOTE(review): the bottom partition is the tiny-images based class while the
    # top partition reads OpenImages - confirm this pairing is intended.
    bottom_dataset = CifarTinyImageBottomKPartition(model_confidences, top_k_indices, transform_base=transform,
                                                    temperature=temperature, soft_labels=soft_labels,
                                                    exclude_cifar=exclude_imagenet100)

    bottom_loader = torch.utils.data.DataLoader(bottom_dataset, shuffle=True, batch_size=batch_size, num_workers=1)

    # The config entries below previously tested an undefined `dataset` variable;
    # they now record the exclusion flag that is actually passed to the datasets.
    if id_config_dict is not None:
        id_config_dict['Dataset'] ='Cifar-SSL'
        id_config_dict['Train validation split'] = imagenet100TrainValSplit
        id_config_dict['Batch out_size'] = batch_size
        id_config_dict['Samples per class'] = samples_per_class
        id_config_dict['Soft labels'] = soft_labels
        id_config_dict['Class balanced'] = id_class_balanced
        id_config_dict['Exclude CIFAR'] = bool(exclude_imagenet100)
        id_config_dict['Augmentation'] = augm_config

    if od_config_dict is not None:
        od_config_dict['Dataset'] = 'TinyImagesPartition'
        od_config_dict['Batch out_size'] = batch_size
        od_config_dict['Verbose exclude'] = verbose_exclude
        od_config_dict['Exclude CIFAR'] = bool(exclude_imagenet100)
        od_config_dict['Augmentation'] = augm_config

    return top_loader, bottom_loader


def get_tiny_partition_all_sampling(dataset_classifications_path, teacher_model, dataset, cifarTrainValSplit,
                                    unlabeled_ratio,
                                    class_tpr_min=None, od_exclusion_threshold=None, calibrate_temperature=False,
                                    soft_labels=True, batch_size=100,
                                    augm_type='default', cutout_window=16, size=32, num_workers=8, id_config_dict=None,
                                    exclude_cifar=False, exclude_cifar10_1=False,
                                    od_config_dict=None, ssl_config=None):
    """Build the "all valid samples" top loader and the matching bottom loader.

    The top dataset keeps every sample that passes the teacher's confidence
    threshold (the per-class cap is set to 1e9) and is iterated through an
    ``AllValidSampler`` configured with ``unlabeled_ratio``. All indices the
    top dataset does not use go into the bottom partition.

    Returns:
        Tuple ``(top_loader, bottom_loader)``. ``id_config_dict`` and
        ``od_config_dict`` are filled in place when they are given.
    """
    # A guard rejecting num_workers > 1 ("Bug in the current multithreaded
    # tinyimages implementation") used to live here and is currently disabled.

    teacher_outputs = load_teacher_data(dataset_classifications_path, teacher_model,
                                        class_tpr_min=class_tpr_min,
                                        od_exclusion_threshold=od_exclusion_threshold,
                                        calibrate_temperature=calibrate_temperature,
                                        ssl_config=ssl_config)
    model_logits, _, class_thresholds, temperature = teacher_outputs

    augm_config = {}
    # NOTE(review): get_cifar10_augmentation is not among the imports visible at
    # the top of this file - confirm it is in scope.
    transform = get_cifar10_augmentation(augm_type, cutout_window, out_size=size, config_dict=augm_config)

    # In-distribution side: labeled data plus every confident teacher-selected sample.
    top_dataset = ImageNet100OpenImagesTopKPartition(
        model_logits,
        samples_per_class=1e9,
        transform=transform,
        imageNet100TrainValSplit=cifarTrainValSplit,
        min_conf=class_thresholds,
        temperature=temperature,
        soft_labels=soft_labels,
        exclude_imageNet100=exclude_cifar,
    )
    sampler = AllValidSampler(top_dataset, unlabeled_ratio=unlabeled_ratio)
    top_loader = torch.utils.data.DataLoader(top_dataset, sampler=sampler,
                                             batch_size=batch_size, num_workers=num_workers)

    # Out-distribution side: everything the top dataset does not use.
    used_indices = top_dataset.get_used_semi_indices(False)
    bottom_dataset = CifarTinyImageBottomKPartition(
        model_logits,
        used_indices,
        transform_base=transform,
        temperature=temperature,
        soft_labels=soft_labels,
        exclude_cifar=exclude_cifar,
        exclude_cifar10_1=exclude_cifar10_1,
    )
    bottom_loader = torch.utils.data.DataLoader(bottom_dataset, shuffle=True, batch_size=batch_size, num_workers=1)

    if id_config_dict is not None:
        id_config_dict.update({
            'Dataset': 'ImageNet100-SSL',
            'Train validation split': cifarTrainValSplit,
            'Unlabeled ratio': unlabeled_ratio,
            'Batch out_size': batch_size,
            'Soft labels': soft_labels,
            'Exclude CIFAR': dataset in ['CIFAR10', 'CIFAR100'],
            'Augmentation': augm_config,
        })

    if od_config_dict is not None:
        od_config_dict.update({
            'Dataset': 'TinyImagesPartition',
            'Batch out_size': batch_size,
            'Exclude CIFAR': dataset in ['CIFAR10', 'CIFAR100'],
            'Augmentation': augm_config,
        })

    return top_loader, bottom_loader



class CifarTinyImageBottomKPartition(torch.utils.data.Dataset):
    """Tiny-images samples that are neither used by the top-k partition nor excluded.

    Each item is ``(image, target)``. The target is the temperature-scaled
    softmax of the teacher logits when ``soft_labels`` is true, and a uniform
    distribution over the classes otherwise.
    """

    def __init__(self, model_logits, top_k_indices, transform_base, temperature=1, soft_labels=True,
                 exclude_cifar=False, exclude_cifar10_1=False):
        """Select the usable indices and open the image file.

        Args:
            model_logits: Teacher logits, one row per sample and one column per class.
            top_k_indices: Indices already used by the top-k partition.
            transform_base: Transform applied after conversion to a PIL image.
            temperature: Divisor applied to the logits before the softmax.
            soft_labels: Selects the target type (see class docstring).
            exclude_cifar, exclude_cifar10_1: Forwarded to
                ``_load_cifar_exclusion_idcs`` to obtain further indices to drop.
        """
        self.data_location = get_tiny_images_files(False)
        # The handle is kept open for the lifetime of the dataset; it is never
        # closed by this class.
        self.fileID = open(self.data_location, "rb")
        self.soft_labels = soft_labels
        self.num_classes = model_logits.shape[1]
        self.temperature = temperature

        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transform_base])

        self.model_logits = model_logits

        cifar_idxs = _load_cifar_exclusion_idcs(exclude_cifar, exclude_cifar10_1)
        cifar_idxs = torch.LongTensor(cifar_idxs)

        self.valid_indices = self._select_valid_indices(self.model_logits.shape[0], top_k_indices, cifar_idxs)
        self.length = len(self.valid_indices)

        print(f'Exclude Cifar {exclude_cifar} - Samples {self.length} - Temperature {self.temperature}')

    @staticmethod
    def _select_valid_indices(num_samples, top_k_indices, excluded_indices):
        """Return the sorted 1-D tensor of indices in ``range(num_samples)`` not listed in either argument."""
        keep = torch.ones(num_samples, dtype=torch.bool)
        keep[excluded_indices] = False
        keep[top_k_indices] = False
        # flatten() instead of squeeze(): squeeze() would collapse a single
        # surviving index into a 0-d tensor, on which len() raises.
        return torch.nonzero(keep, as_tuple=False).flatten()

    def __getitem__(self, index):
        valid_index = self.valid_indices[index]
        img = _load_tiny_image(valid_index, self.fileID)

        if self.transform is not None:
            img = self.transform(img)

        if self.soft_labels:
            model_prediction = torch.softmax(self.model_logits[valid_index, :] / self.temperature, dim=0)
        else:
            model_prediction = (1./self.num_classes) * torch.ones(self.num_classes)

        return img, model_prediction

    def __len__(self):
        return self.length

class ImageNet100OpenImagesTopKPartition(torch.utils.data.Dataset):
    """ImageNet100 training samples followed by teacher-selected OpenImages samples.

    Index layout:
      * ``[0, num_imagenet100_samples)``: labeled ImageNet100 samples, grouped
        by class.
      * afterwards, for class ``i`` in order, ``class_semi_counts[i]`` OpenImages
        samples predicted as class ``i`` by the teacher, sorted by descending
        teacher confidence.

    Items are ``(image, label)``. With ``soft_labels`` the label is a
    probability vector (one-hot for labeled samples, temperature-scaled teacher
    softmax for OpenImages samples); otherwise it is a class index.
    """

    def __init__(self, model_logits, samples_per_class, transform, imageNet100TrainValSplit, min_conf, temperature=1,
                 soft_labels=True, exclude_imageNet100=False, preload=False):
        """Select the OpenImages samples to use for every class.

        Args:
            model_logits: Teacher logits, one row per OpenImages sample.
            samples_per_class: Upper bound on the selected samples per class.
            transform: Transform applied to every returned image (may be None).
            imageNet100TrainValSplit: Must be truthy; the other case is not implemented.
            min_conf: Per-class minimum teacher confidence, indexed by class.
            temperature: Divisor applied to the logits for soft labels.
            soft_labels: Selects the label type (see class docstring).
            exclude_imageNet100: If true, the indices listed in
                'openImages_imageNet100_duplicates.txt' are never selected.
            preload: If true, the selected OpenImages images are loaded once
                in the constructor and kept in memory.

        Raises:
            NotImplementedError: If ``imageNet100TrainValSplit`` is falsy.
        """
        self.samples_per_class = samples_per_class
        self.soft_labels = soft_labels
        self.temperature = temperature
        self.transform = transform
        self.model_logits = model_logits
        self.preload = preload

        predicted_max_conf, predicted_class = torch.max(torch.softmax(self.model_logits,dim=1), dim=1)

        class_labels = get_ImageNet100_labels()
        path = get_imagenet_path()
        if imageNet100TrainValSplit:
            self.imagenet100_dataset = ImageNet100TrainValidationSplit(path, train=True, transform=None)
        else:
            raise NotImplementedError()

        self.num_classes = 100

        openimages_path = get_openimages_path()
        self.openimages_dataset = OpenImages(openimages_path, 'train', transform=None)

        # Mask of OpenImages samples that may be selected.
        non_imagenet100 = torch.ones(self.model_logits.shape[0], dtype=torch.bool)

        exclude_idcs = []
        if exclude_imageNet100:
            with open('openImages_imageNet100_duplicates.txt', 'r') as idxs:
                exclude_idcs = [int(line) for line in idxs if line.strip()]

        imagnet100_duplicate_idxs = torch.LongTensor(exclude_idcs)
        non_imagenet100[imagnet100_duplicate_idxs] = False

        self.num_imagenet100_samples = len(self.imagenet100_dataset)

        # Labeled sample indices per class. flatten() keeps the result 1-D even
        # when a class has exactly one sample (squeeze() would yield a 0-d
        # tensor that supports neither len() nor torch.cat()).
        self.imagenet100_class_idcs = []
        self.imagenet100_samples_per_class = []
        targets_tensor = torch.LongTensor(self.imagenet100_dataset.targets)
        for i in range(self.num_classes):
            imagenet100_i = torch.nonzero(targets_tensor == i, as_tuple=False).flatten()
            self.imagenet100_class_idcs.append(imagenet100_i)
            self.imagenet100_samples_per_class.append(len(imagenet100_i))

        # Position p of the labeled range maps to dataset index
        # _imagenet100_order[p]; this works for unequal class sizes.
        self._imagenet100_order = torch.cat(self.imagenet100_class_idcs)

        self.in_use_indices = []     # per class: the selected top-k indices
        self.valid_indices = []      # per class: every index passing the threshold
        self.class_semi_counts = []  # per class: len(in_use_indices[i])

        for i in range(self.num_classes):
            candidate_mask = (predicted_class == i) & non_imagenet100 & (predicted_max_conf >= min_conf[i])

            candidate_idcs = torch.nonzero(candidate_mask, as_tuple=False).flatten()
            candidate_order = torch.argsort(predicted_max_conf[candidate_mask], descending=True)

            num_samples_i = int(min(samples_per_class, len(candidate_idcs)))
            class_i_idcs = candidate_idcs[candidate_order[:num_samples_i]]

            self.valid_indices.append(candidate_idcs)
            self.in_use_indices.append(class_i_idcs)
            self.class_semi_counts.append(len(class_i_idcs))

            if num_samples_i < samples_per_class:
                print(f'Incomplete class {class_labels[i]} - Target count: {samples_per_class} - Found samples {len(class_i_idcs)}')

        min_semi_sampels_per_class = min(self.class_semi_counts)
        max_semi_samples_per_class = max(self.class_semi_counts)

        self.num_semi_samples = sum(self.class_semi_counts)
        self.length = self.num_imagenet100_samples + self.num_semi_samples

        print(f'Top K -  Temperature {self.temperature} - Soft labels {soft_labels} - Exclude Cifar {exclude_imageNet100}'
              f'  -  Target Samples per class { self.samples_per_class} - Cifar Samples {self.num_imagenet100_samples}')
        print(f'Min Semi Samples {min_semi_sampels_per_class} - Max Semi samples {max_semi_samples_per_class}'
              f' - Total semi samples {self.num_semi_samples} - Total length {self.length}')

        # Optional cache: class_data[c][k] is the untransformed image of
        # in_use_indices[c][k].
        self.class_data = []
        if self.preload:
            print(f'Preloading images')
            for class_idx in range(self.num_classes):
                self.class_data.append([self.openimages_dataset[int(sample_i)][0]
                                        for sample_i in self.in_use_indices[class_idx]])

    #if verbose exclude, include all indices that fulfill the conf requirement but that are outside of the top-k range
    def get_used_semi_indices(self, verbose_exclude=False):
        """Return the concatenated per-class OpenImages indices (all valid ones if ``verbose_exclude``)."""
        if verbose_exclude:
            return torch.cat(self.valid_indices)
        else:
            return torch.cat(self.in_use_indices)

    def _load_cifar_image(self, cifar_idx):
        """Load one labeled sample by its index in the labeled dataset."""
        img, label = self.imagenet100_dataset[int(cifar_idx)]

        # The labeled dataset is created with transform=None, so the
        # transform has to be applied here.
        if self.transform is not None:
            img = self.transform(img)

        if self.soft_labels:
            one_hot_label = torch.zeros(self.num_classes)
            one_hot_label[label] = 1.0
            return img, one_hot_label
        else:
            return img, label

    def _load_tiny_image(self, class_idx, tiny_lin_idx):
        """Load the ``tiny_lin_idx``-th selected OpenImages sample of class ``class_idx``."""
        valid_index = int(self.in_use_indices[class_idx][tiny_lin_idx])
        if self.preload:
            img = self.class_data[class_idx][tiny_lin_idx]
        else:
            img = self.openimages_dataset[valid_index][0]

        if self.transform is not None:
            img = self.transform(img)

        if self.soft_labels:
            label = torch.softmax(self.model_logits[valid_index, :] / self.temperature, dim=0)
        else:
            label = torch.argmax(self.model_logits[valid_index, :]).item()
        return img, label

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``; raises IndexError when out of range."""
        index = int(index)
        if index < 0:
            index += self.length
        if not 0 <= index < self.length:
            raise IndexError(f'index {index} is out of range for a dataset of length {self.length}')

        if index < self.num_imagenet100_samples:
            return self._load_cifar_image(self._imagenet100_order[index])

        # Walk the per-class blocks to find the one containing the index.
        index_semi = index - self.num_imagenet100_samples
        cumulative_idx = 0
        for class_idx in range(self.num_classes):
            next_cumulative = cumulative_idx + self.class_semi_counts[class_idx]
            if index_semi < next_cumulative:
                return self._load_tiny_image(class_idx, index_semi - cumulative_idx)
            cumulative_idx = next_cumulative

        raise IndexError(f'index {index} is out of range for a dataset of length {self.length}')

    def __len__(self):
        return self.length

class AllValidSampler(Sampler):
    """Sampler interleaving labeled samples with class-balanced unlabeled samples.

    The dataset is expected to list its labeled samples first, followed by the
    unlabeled samples of class 0, class 1, ... with ``class_semi_counts[c]``
    entries each. Every ``1 + unlabeled_ratio`` consecutive positions of one
    pass consist of one labeled index followed by ``unlabeled_ratio`` unlabeled
    indices. Each labeled sample appears exactly once per pass. Unlabeled slots
    are distributed evenly over the classes that have samples; within a class,
    indices are drawn as successive random permutations, so small classes are
    repeated.
    """

    def __init__(self, cifar_top_k_partition, unlabeled_ratio):
        """Read the partition layout.

        Args:
            cifar_top_k_partition: Dataset exposing ``class_semi_counts`` and
                the labeled sample count as ``num_cifar_samples`` or
                ``num_imagenet100_samples``.
            unlabeled_ratio: Non-negative integer number of unlabeled samples
                per labeled sample.

        Raises:
            ValueError: If the ratio is not a non-negative integer, or if
                unlabeled samples are requested but no class provides any.
        """
        try:
            # Older PyTorch releases require the (unused) data_source argument.
            super().__init__(None)
        except TypeError:
            super().__init__()

        ratio = int(unlabeled_ratio)
        if ratio != unlabeled_ratio or ratio < 0:
            raise ValueError(f'unlabeled_ratio must be a non-negative integer, got {unlabeled_ratio}')

        self.unlabeled_ratio = ratio
        self.unlabeled_per_class = [int(count) for count in cifar_top_k_partition.class_semi_counts]
        self.num_classes = len(self.unlabeled_per_class)
        self.max_unlabeled_per_class = max(self.unlabeled_per_class, default=0)

        # The partition class in this file names the labeled count
        # `num_imagenet100_samples`; `num_cifar_samples` is still tried first.
        labeled_count = getattr(cifar_top_k_partition, 'num_cifar_samples', None)
        if labeled_count is None:
            labeled_count = cifar_top_k_partition.num_imagenet100_samples

        self.cifar_samples = int(labeled_count)
        self.unlabeled_samples = self.cifar_samples * ratio
        self.unlabeled_samples_per_class = (self.unlabeled_samples // self.num_classes
                                            if self.num_classes > 0 else 0)
        self.length = self.cifar_samples + self.unlabeled_samples

        if self.unlabeled_samples > 0 and self.max_unlabeled_per_class == 0:
            raise ValueError('unlabeled samples were requested but no class has any unlabeled samples')

        print(f'Top-K Sampler: Ratio {unlabeled_ratio} - Labeled {self.cifar_samples}'
              f' - Unlabeled {self.unlabeled_samples} - Total {self.length}')

    @staticmethod
    def _cycled_permutations(pool_size, num_draws):
        """Return ``num_draws`` values from ``range(pool_size)`` as concatenated random permutations."""
        chunks = []
        remaining = num_draws
        while remaining > 0:
            chunk = torch.randperm(pool_size)[:remaining]
            chunks.append(chunk)
            remaining -= chunk.numel()
        if not chunks:
            return torch.empty(0, dtype=torch.long)
        return torch.cat(chunks)

    def _draw_unlabeled(self):
        """Return the dataset indices for all unlabeled slots of one pass, in slot order."""
        # Only classes that own samples receive slots; a class without samples
        # has nothing to draw from.
        non_empty = torch.tensor([cls for cls, count in enumerate(self.unlabeled_per_class) if count > 0],
                                 dtype=torch.long)

        # Balanced class assignment (round robin), then shuffled.
        slot_class = non_empty[torch.arange(self.unlabeled_samples) % len(non_empty)]
        slot_class = slot_class[torch.randperm(self.unlabeled_samples)]

        drawn = torch.empty(self.unlabeled_samples, dtype=torch.long)
        cls_start = self.cifar_samples  # in the dataset, unlabeled samples start behind the labeled ones
        for cls, count in enumerate(self.unlabeled_per_class):
            if count > 0:
                slots = slot_class == cls
                drawn[slots] = cls_start + self._cycled_permutations(count, int(slots.sum()))
            cls_start += count
        return drawn

    def __iter__(self):
        period = 1 + self.unlabeled_ratio
        idcs = torch.empty(self.length, dtype=torch.long)

        # Every `period`-th position holds a labeled sample; since
        # length == cifar_samples * period there are exactly cifar_samples of them.
        labeled_slots = (torch.arange(self.length) % period) == 0
        idcs[labeled_slots] = torch.randperm(self.cifar_samples)

        if self.unlabeled_samples > 0:
            idcs[~labeled_slots] = self._draw_unlabeled()

        # Plain Python ints, so that datasets can index with them directly.
        return iter(idcs.tolist())

    def __len__(self):
        return self.length

